# -*- coding: utf-8 -*-
"""ICLR_work_due_Sept_21st.ipynb
"""

from google.colab import drive
drive.mount("/content/gdrive")  # make Google Drive visible inside the Colab VM
# Absolute path under the mount point above, so file access does not depend on
# the current working directory being /content (presumably Colab's default).
path = "/content/gdrive/MyDrive/AGNNs/"

#MAIN PROGRAM BEGINS HERE

import json
from random import sample, choice, seed
import gc #garbage collection
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

gc.collect()  # free unused memory at initialization

# 100 distinct video ids sampled (unseeded) from "10000".."14999".
video_list = sample([str(10000 + i) for i in range(5000)], 100)
# Directory holding one sim_<video_id>.json collision-prediction file per video.
propnet_path = path + "with_edge_supervision_old/"
# Questions and ground-truth programs, keyed by video id.
question_info_path = path + 'mc_1000q_4000c_val_new.json'

#==============================================
#START: NEURAL NET hyperparameters (METHOD 1: SYM_GEN)
#==============================================
e_size = 1536  # token/position/node embedding size; must be divisible by n_heads [96,384,768]
c_size = None  # max context size (question words + program tokens); set by get_video_data_and_knowledge()
training_epochs = 100  # no. of epochs; one optimizer step per epoch [100,500,1000]
early_stopping = False  # TODO: not implemented yet
n_blocks = 1  # no. of self-attention blocks [2,6,12,24?]
n_layers = 0  # no. of extra h_size->h_size hidden layers in the feed-forward stack [2,3]
h_size = 1536  # hidden layer size of the feed-forward stack [100, 1000]
n_heads = 12  # no. of attention heads per attention block [12]
n_tokens = None  # vocabulary size; set by Tokenizer()
max_n = 10  # max no. of knowledge-graph nodes (rows of the node-embedding table)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tk = None  # tokenizer as a global variable; set by main()
batch_size = 8  # no. of single-question examples sampled per epoch
#==============================================
#END: NEURAL NET hyperparameters (METHOD 1: SYM_GEN)
#==============================================

# (question_str, collision-order graph as tuple of edge tuples) -> program token list;
# filled by get_video_data_and_knowledge().
dataset = {}

def get_video_data_and_knowledge():
  """Load the sampled videos' collision predictions and questions into ``dataset``.

  For every id in ``video_list`` this reads ``propnet_path + 'sim_<id>.json'``
  and builds the video's collision-order graph (COG): one directed edge from
  each collision in ``predictions[0]['collisions']`` to the next one in that
  list. Each collision is named ``<obj1 values>_collides_with_<obj2 values>``.
  Every question of the video found in ``question_info_path`` is then stored as
  ``dataset[(question_str, COG)] = program_gt``, where COG is a tuple of
  ``(collision, next_collision)`` tuples.

  Side effects: mutates the module-level ``dataset`` and sets the module-level
  ``c_size`` to the largest ``#question words + #program tokens`` seen.

  Raises:
    ValueError: if the sampled videos have no questions (``c_size`` would be
      undefined).
  """
  global c_size

  with open(question_info_path, encoding='utf-8') as f:
    all_video_info = json.load(f)

  c_sizes = []
  for video in video_list:
    # Parse one video at a time instead of keeping every video's JSON in memory.
    with open(propnet_path + 'sim_' + video + '.json', encoding='utf-8') as f:
      video_info = json.load(f)

    # Collision-order graph: consecutive collisions become directed edges.
    collisions = video_info['predictions'][0]['collisions']
    cog = tuple(
        (_collision_description(current), _collision_description(following))
        for current, following in zip(collisions, collisions[1:])
    )

    for question in all_video_info[video]['questions']:
      q_str = question['question']
      q_prg = question['program_gt']
      c_sizes.append(len(q_str.split(' ')) + len(q_prg))
      dataset[(q_str, cog)] = q_prg

  gc.collect()  # free garbage left over from per-video parsing before returning

  if not c_sizes:
    raise ValueError("no questions found for the sampled videos; cannot set c_size")
  c_size = max(c_sizes)


def _collision_description(collision):
  """Name a collision by joining both objects' attribute values with '_'."""
  obj1, obj2 = collision['objects'][0], collision['objects'][1]
  return '_'.join(obj1.values()) + '_collides_with_' + '_'.join(obj2.values())

class Tokenizer(object):
  """Word-level tokenizer over question words and program tokens, plus 'END'.

  The vocabulary is built from every ``(question_str, graph) -> program`` entry
  of ``dataset``. Questions are split on single spaces; program tokens are used
  as-is. Token ids are assigned in sorted order, so they are reproducible across
  runs; ``set`` iteration order for strings depends on hash randomization.
  'END' always gets the last id. Construction also publishes the vocabulary size
  to the module-level ``n_tokens``.
  """

  def __init__(self,
               dataset):

    vocab = set()
    for data_point in dataset:
      vocab.update(data_point[0].split(' '))  # question words
      vocab.update(dataset[data_point])        # program tokens

    # Keep 'END' out of the sorted part so it appears exactly once, as the last id.
    vocab.discard('END')
    # key=str keeps sorting well-defined even if a program token is not a str.
    self.tokens = sorted(vocab, key=str)
    self.tokens.append('END')
    self.token_idx = {token: idx for idx, token in enumerate(self.tokens)}
    self.n_tokens = len(self.tokens)

    global n_tokens; n_tokens = self.n_tokens

    gc.collect()

  def encode(self, inp):
    """Map a question string or a program token sequence to token ids.

    A ``str`` is treated as a question and split on single spaces. Any other
    sequence (list or tuple) is treated as a program and gets 'END' appended.

    Raises:
      KeyError: if a token is not in the vocabulary.
    """
    if isinstance(inp, str):
      inp_tokens = inp.split(' ')
    else:
      inp_tokens = list(inp) + ['END']
    return [self.token_idx[token] for token in inp_tokens]

  def decode(self, encoding):
    """Map token ids back to their tokens."""
    return [self.tokens[x] for x in encoding]

class MHA(nn.Module):
  """One self-attention block: linear Q/K/V projections feeding ``nn.MultiheadAttention``.

  Sizes come from the module-level ``e_size`` and ``n_heads`` at construction
  time. ``nn.MultiheadAttention`` requires ``e_size`` to be divisible by
  ``n_heads``.

  NOTE(review): ``nn.MultiheadAttention`` applies its own learned input
  projections, so ``query``/``key``/``value`` compose with those. They are kept
  to preserve this module's parameter layout.
  """

  def __init__(self):

    super().__init__()
    self.query = nn.Linear(e_size, e_size, bias=False)
    self.key = nn.Linear(e_size, e_size, bias=False)
    self.value = nn.Linear(e_size, e_size, bias=False)
    self.multihead_atttn = nn.MultiheadAttention(e_size, n_heads)  # batch_first=False

  def forward(self, inp):
    """Self-attend over ``inp`` and return the attention output.

    ``inp`` is expected to be ``(seq_len, e_size)`` (unbatched) or
    ``(seq_len, batch, e_size)``, the layouts ``nn.MultiheadAttention``
    accepts with ``batch_first=False``. No attention mask is applied.
    """
    Q, K, V = self.query(inp), self.key(inp), self.value(inp)
    x, _ = self.multihead_atttn(Q, K, V)  # attention weights are not needed
    # No gc.collect() here: this runs once per block per prefix, and a full
    # collection on every call is costly.
    return x

class DataOps(object):
  """Static helpers that sample training examples and turn them into ids/tensors.

  ``get_batch`` reads the module-level ``dataset``. ``process_batch`` reads the
  module-level tokenizer ``tk``.
  """

  @staticmethod
  def get_batch():
    """Return a one-entry dict ``{(question, graph): program}`` sampled uniformly from ``dataset``.

    Raises:
      ValueError: (from ``random.sample``) if ``dataset`` is empty.
    """
    (key,) = sample(list(dataset), 1)
    return {key: dataset[key]}

  @staticmethod
  def process_graph(graph):
    """Build a dense adjacency matrix from an edge list.

    Args:
      graph: iterable of ``(source, target)`` edges whose node labels are
        hashable. The collision-order graph uses strings.

    Returns:
      ``(graph_matrix, n_graph_nodes)``, where ``graph_matrix`` is an
      ``(n, n)`` float tensor with ``graph_matrix[i, j] == 1.0`` iff an edge
      goes from node ``i`` to node ``j``. Nodes are indexed in sorted label
      order, so the mapping is reproducible across runs; ``set`` order for
      strings depends on hash randomization.
    """
    edges = [tuple(edge) for edge in graph]
    # key=str keeps sorting well-defined for any hashable labels.
    graph_nodes = sorted({node for edge in edges for node in edge}, key=str)
    graph_node_idx = {node: idx for idx, node in enumerate(graph_nodes)}
    n_graph_nodes = len(graph_nodes)

    graph_matrix = torch.zeros(n_graph_nodes, n_graph_nodes)
    for edge in edges:
      graph_matrix[graph_node_idx[edge[0]], graph_node_idx[edge[1]]] = 1.0

    return graph_matrix, n_graph_nodes

  @staticmethod
  def process_batch(batch):
    """Expand one example into next-token-prediction pairs.

    The encoded question and the encoded program (which ends with 'END') are
    concatenated into one sequence ``s``. For every position ``t``, the input is
    the prefix ``s[:t + 1]`` and the target is ``s[t + 1]``.

    Args:
      batch: dict whose first key is ``(question_str, graph)`` and whose value
        is that question's program token list.

    Returns:
      ``(X, Y)``: a list of prefix id lists and the list of their next-token ids.
    """
    batch_key = next(iter(batch))
    question, program = batch_key[0], batch[batch_key]
    sequence = tk.encode(question) + tk.encode(program)
    # `seq_len`-style naming avoids shadowing the global `batch_size` hyperparameter.
    X = [sequence[:t + 1] for t in range(len(sequence) - 1)]
    Y = sequence[1:]
    return X, Y

class generator(nn.Module):
  """Next-token transformer over question+program tokens with an auxiliary
  knowledge-graph (collision-order graph) reconstruction loss.

  Module-level hyperparameters (``n_tokens``, ``e_size``, ``c_size``, ``max_n``,
  ``n_blocks``, ``h_size``, ``n_layers``) are read at construction.
  ``training_epochs`` and ``batch_size`` are read in ``train``. ``n_tokens`` and
  ``c_size`` are ``None`` until ``Tokenizer`` and
  ``get_video_data_and_knowledge`` have run.
  """

  def __init__(self):

    super().__init__()
    self.embeddings = nn.Embedding(n_tokens, e_size)
    self.pos_embeddings = nn.Embedding(c_size, e_size)
    self.node_embeddings = nn.Embedding(max_n, e_size)
    self.norm = nn.LayerNorm(e_size)    # for e_size-wide activations
    self.h_norm = nn.LayerNorm(h_size)  # for h_size-wide activations (h_size may differ from e_size)
    self.MHA_blocks = nn.ModuleList([MHA() for _ in range(n_blocks)])
    self.ffn_inp = nn.Linear(e_size, h_size)
    self.ffns = nn.ModuleList([nn.Linear(h_size, h_size) for _ in range(n_layers)])
    self.ffn_out = nn.Linear(h_size, e_size)
    self.head = nn.Linear(e_size, n_tokens)

  def graph_create(self,
                   n_es,
                   rep,
                   G=None):
    """Predict an asymmetric soft adjacency matrix from node embeddings.

    Each node embedding is shifted by ``rep`` and squashed with a sigmoid. ``eps``
    is added so the log stays finite. For ``i != j`` the score is
    ``sigmoid(sum_k x_ik*log(x_ik) - sum_k x_ik*log(x_jk))``, a sigmoid of a
    KL-style divergence. The diagonal is 0.

    NOTE(review): the generalized-KL Bregman divergence also has
    ``- x_ik + x_jk`` terms, which are not included here.

    Args:
      n_es: ``(n, e_size)`` node embeddings.
      rep: ``(e_size,)`` representation added to every node embedding.
      G: optional ``(n, n)`` mask multiplied element-wise into the result.

    Returns:
      ``(n, n)`` tensor of scores in ``[0, 1]``.
    """
    eps = 1e-7  # log-domain correction: log(0) is not allowed
    x = torch.sigmoid(n_es + rep) + eps
    log_x = torch.log(x)
    self_term = torch.sum(x * log_x, dim=1)  # (n,): sum_k x_ik log x_ik
    cross_term = x @ log_x.t()               # (n, n): sum_k x_ik log x_jk
    n = x.shape[0]
    off_diagonal = 1.0 - torch.eye(n, device=x.device, dtype=x.dtype)
    kg_a = torch.sigmoid(self_term.unsqueeze(1) - cross_term) * off_diagonal
    return kg_a if G is None else kg_a * G

  def forward(self, X):
    """Score the next token after each prefix in ``X``.

    Args:
      X: non-empty list of token-id lists (prefixes). Each has length at most
        ``c_size``, the size of the position-embedding table.

    Returns:
      ``(logits, rep)``. ``logits`` is ``(len(X), n_tokens)``; row ``i`` holds
      the next-token logits for prefix ``i``. ``rep`` is the mean final hidden
      state of the LAST prefix in ``X``, with shape ``(e_size,)``.

    Raises:
      ValueError: if ``X`` is empty.
    """
    if not X:
      raise ValueError("forward() needs at least one prefix")
    device = self.embeddings.weight.device  # follow wherever the model was moved

    logits = []
    for x in X:
      idxs = torch.tensor(x, device=device)
      positions = torch.arange(len(x), device=device)
      e_x = self.embeddings(idxs) + self.pos_embeddings(positions)
      # NOTE(review): attention blocks have no residual connection.
      for MHA_block in self.MHA_blocks:
        e_x = MHA_block(self.norm(e_x))
      e_x = F.leaky_relu(self.h_norm(self.ffn_inp(e_x)))
      for ffn in self.ffns:
        e_x = F.leaky_relu(self.h_norm(ffn(e_x)))
      e_x = F.leaky_relu(self.norm(self.ffn_out(e_x)))
      rep = torch.mean(e_x, dim=0)
      logits.append(self.head(e_x[-1]))  # only the last position predicts the next token

    return torch.stack(logits), rep

  def train(self,
            method='sym',
            mode=None):
    """Train with AdamW, or set train/eval mode as ``nn.Module.train`` does.

    This overrides ``nn.Module.train``, which ``nn.Module.eval()`` calls as
    ``self.train(False)``. Boolean arguments (or ``mode=``) are therefore
    forwarded to ``nn.Module.train``, so ``eval()`` switches modes instead of
    starting a training run.

    Each epoch samples ``batch_size`` examples and averages their losses.
    Each loss is cross-entropy plus the squared Frobenius distance between the
    true and predicted graph. The epoch ends with one optimizer step.
    Examples that fail are reported and skipped.

    Args:
      method: 'sym' (sigmoid of node-embedding dot products), 'asym'
        (``graph_create``), or 'cbf' (``graph_create`` masked by the true
        graph). A bool is treated as ``mode``.
      mode: if given, behave like ``nn.Module.train(mode)``.

    Returns:
      ``self`` when used to set the mode; otherwise ``None``.

    Raises:
      ValueError: for an unknown ``method``.
    """
    if mode is not None:
      return super().train(mode)
    if isinstance(method, bool):
      return super().train(method)
    if method not in ('sym', 'asym', 'cbf'):
      raise ValueError(f"unknown method {method!r}; expected 'sym', 'asym' or 'cbf'")

    super().train(True)  # make sure we are in training mode (e.g. after eval())
    optimizer = torch.optim.AdamW(self.parameters())
    criterion = nn.CrossEntropyLoss()

    for epoch in range(training_epochs):
      batch_losses = []
      for _ in tqdm(range(batch_size)):
        try:
          loss = self._batch_loss(method, criterion)
        except Exception as err:  # best-effort: skip a bad example, but never silently
          print(f"skipping batch ({type(err).__name__}: {err})")
          continue
        if loss is not None:
          batch_losses.append(loss)

      if not batch_losses:
        print("epoch", epoch, "skipped: no usable batches")
        continue

      # Back-propagate the mean over ALL sampled examples, not just the last one.
      total_loss = torch.stack(batch_losses).mean()
      print("total loss", total_loss.item())
      optimizer.zero_grad()
      total_loss.backward()
      optimizer.step()
      del total_loss, batch_losses; gc.collect()  # drop this epoch's autograd graph
    gc.collect()  # free any remaining unused memory

  def _batch_loss(self, method, criterion):
    """Sample one example and return its generation + graph loss (``None`` to skip it)."""
    device = self.embeddings.weight.device
    batch = DataOps.get_batch()
    graph = next(iter(batch))[1]
    X, Y = DataOps.process_batch(batch)
    G, nG = DataOps.process_graph(graph)
    # The node-embedding table has max_n rows, and an empty graph has nothing to reconstruct.
    if not 0 < nG <= max_n:
      print(f"skipping example: graph has {nG} nodes (need 1..{max_n})")
      return None
    G = G.to(device)

    logits, rep = self(X)
    n_es = self.node_embeddings(torch.arange(nG, device=device))
    if method == 'sym':
      h = n_es + rep  # integrate the sequence representation into every node
      kg_a = torch.sigmoid(h @ h.t())
    elif method == 'asym':
      kg_a = self.graph_create(n_es, rep)
    else:  # 'cbf': only score true edges
      kg_a = self.graph_create(n_es, rep, G=G)

    kg_loss = torch.sum((G - kg_a) ** 2)  # squared Frobenius norm
    # Class-index targets: equivalent to the one-hot probability form, but cheaper.
    gen_loss = criterion(logits, torch.tensor(Y, device=device))
    return gen_loss + kg_loss

def main(method='sym'):
  """Build the dataset and tokenizer, then train a ``generator``.

  Args:
    method: knowledge-graph head passed to ``generator.train``: 'sym', 'asym'
      or 'cbf'.

  Returns:
    The trained ``generator`` instance.
  """
  global tk

  get_video_data_and_knowledge()
  tk = Tokenizer(dataset)
  model = generator()
  model.train(method=method)  # honor the caller's choice instead of hard-coding 'sym'
  return model

TEST = True  # set to False to only define everything without training
# Train only when run as a script/notebook (__main__), not when imported.
if TEST and __name__ == "__main__":
  main()
